{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "Using Deep Learning to Annotate the Protein Universe.ipynb",
      "provenance": [],
      "collapsed_sections": [
        "EzXGgAYDN2Os",
        "v-6Ufto98_Ro",
        "Pca8JUJH9GyK",
        "egM-2RHJ9Bnm",
        "MYaOgeZx9JrV",
        "IR0Gl4Mx9Mgm",
        "p9iUCIO99QPs",
        "_BMjr7rr9TBt",
        "CSJ-lE1rFpFi",
        "P-e-xWYw9Uzl"
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2_PsPZCvNKtI"
      },
      "source": [
        "```\n",
        "# Copyright 2021 Google Inc.\n",
        "\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9n1hzmoE9onA"
      },
      "source": [
        "# This code supports the publication \"Using Deep Learning to Annotate the Protein Universe\".\n",
        "[preprint link](https://doi.org/10.1101/626507)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QX30QL0bIiPs"
      },
      "source": [
        "**Note**: We recommend you enable a free GPU by going to:\n",
        "\n",
        "> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EzXGgAYDN2Os"
      },
      "source": [
        "# Set-up"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v-6Ufto98_Ro"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "P6_SeJumHUGK"
      },
      "source": [
        "import json\n",
        "import numpy as np\n",
        "import tensorflow.compat.v1 as tf\n",
        "\n",
        "# Suppress noisy log messages.\n",
        "from tensorflow.python.util import deprecation\n",
        "deprecation._PRINT_DEPRECATION_WARNINGS = False"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Pca8JUJH9GyK"
      },
      "source": [
        "## Library functions: convert sequence to one-hot array (input to model)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Lp-3Ccx09IJ7"
      },
      "source": [
        "AMINO_ACID_VOCABULARY = [\n",
        "    'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R',\n",
        "    'S', 'T', 'V', 'W', 'Y'\n",
        "]\n",
        "\n",
        "# Characters treated as gaps/padding; each one encodes as the all-zero vector.\n",
        "_PFAM_GAP_CHARACTER = '.'\n",
        "_GAP_CHARACTERS = (_PFAM_GAP_CHARACTER, '-')\n",
        "\n",
        "def residues_to_one_hot(amino_acid_residues):\n",
        "  \"\"\"Given a sequence of amino acids, return one hot array.\n",
        "\n",
        "  Supports ambiguous amino acid characters B, Z, and X by distributing evenly\n",
        "  over possible values, e.g. an 'X' gets mapped to [.05, .05, ... , .05].\n",
        "\n",
        "  Supports rare amino acids by substituting before encoding: 'U' is treated\n",
        "  as 'C' and 'O' is treated as 'X'.\n",
        "\n",
        "  Supports gaps and pads with the '.' and '-' characters; which are mapped to\n",
        "  the zero vector.\n",
        "\n",
        "  Args:\n",
        "    amino_acid_residues: string. consisting of characters from\n",
        "      AMINO_ACID_VOCABULARY, optionally with B, Z, X, U, O, '.' and '-'.\n",
        "\n",
        "  Returns:\n",
        "    A numpy array of shape (len(amino_acid_residues),\n",
        "     len(AMINO_ACID_VOCABULARY)). An empty string gives shape\n",
        "     (0, len(AMINO_ACID_VOCABULARY)).\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if amino_acid_residues has a character that cannot be encoded.\n",
        "  \"\"\"\n",
        "  num_classes = len(AMINO_ACID_VOCABULARY)\n",
        "  normalized_residues = amino_acid_residues.replace('U', 'C').replace('O', 'X')\n",
        "  # Preallocate so the result is 2-D even when the sequence is empty.\n",
        "  one_hot = np.zeros((len(normalized_residues), num_classes))\n",
        "  for position, char in enumerate(normalized_residues):\n",
        "    if char in AMINO_ACID_VOCABULARY:\n",
        "      one_hot[position, AMINO_ACID_VOCABULARY.index(char)] = 1.\n",
        "    elif char == 'B':  # Asparagine or aspartic acid.\n",
        "      one_hot[position, AMINO_ACID_VOCABULARY.index('D')] = .5\n",
        "      one_hot[position, AMINO_ACID_VOCABULARY.index('N')] = .5\n",
        "    elif char == 'Z':  # Glutamine or glutamic acid.\n",
        "      one_hot[position, AMINO_ACID_VOCABULARY.index('E')] = .5\n",
        "      one_hot[position, AMINO_ACID_VOCABULARY.index('Q')] = .5\n",
        "    elif char == 'X':  # Unknown residue: uniform over the vocabulary.\n",
        "      one_hot[position, :] = 1. / num_classes\n",
        "    elif char in _GAP_CHARACTERS:\n",
        "      continue  # The row is already all zeros.\n",
        "    else:\n",
        "      raise ValueError('Could not one-hot code character {}'.format(char))\n",
        "  return one_hot\n",
        "\n",
        "def _test_residues_to_one_hot():\n",
        "  expected = np.zeros((3, 20))\n",
        "  expected[0, 0] = 1.   # Amino acid A\n",
        "  expected[1, 1] = 1.   # Amino acid C\n",
        "  expected[2, :] = .05  # Amino acid X\n",
        "\n",
        "  actual = residues_to_one_hot('ACX')\n",
        "  np.testing.assert_allclose(actual, expected)\n",
        "\n",
        "  # Gap characters are encoded as all-zero rows.\n",
        "  np.testing.assert_allclose(residues_to_one_hot('.-'), np.zeros((2, 20)))\n",
        "  # An empty sequence still has the amino-acid axis.\n",
        "  assert residues_to_one_hot('').shape == (0, 20)\n",
        "_test_residues_to_one_hot()"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FxYneKVaHxLC"
      },
      "source": [
        "def pad_one_hot_sequence(sequence: np.ndarray,\n",
        "                         target_length: int) -> np.ndarray:\n",
        "  \"\"\"Zero-pads a one hot array [seq_len, num_aas] along the seq_len axis.\n",
        "\n",
        "  Args:\n",
        "    sequence: one hot array of shape [seq_len, num_aas].\n",
        "    target_length: number of rows the returned array has.\n",
        "\n",
        "  Returns:\n",
        "    The input rows followed by all-zero rows, target_length rows in total.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if target_length is smaller than the sequence length.\n",
        "  \"\"\"\n",
        "  current_length = sequence.shape[0]\n",
        "  rows_to_add = target_length - current_length\n",
        "  if rows_to_add < 0:\n",
        "    raise ValueError(\n",
        "        'Cannot set a negative amount of padding. Sequence length was {}, target_length was {}.'\n",
        "        .format(current_length, target_length))\n",
        "  # Append rows after the sequence only; the amino-acid axis is left as is.\n",
        "  return np.pad(sequence, ((0, rows_to_add), (0, 0)), mode='constant')\n",
        "\n",
        "def _test_pad_one_hot():\n",
        "  \"\"\"Checks that padding 'ACX' to length 7 appends four all-zero rows.\"\"\"\n",
        "  unpadded = residues_to_one_hot('ACX')\n",
        "  padded = pad_one_hot_sequence(unpadded, 7)\n",
        "\n",
        "  # Expected result: the three encoded residues, then zeros up to row 7.\n",
        "  expected = np.zeros((7, 20))\n",
        "  expected[:3] = unpadded\n",
        "  np.testing.assert_allclose(expected, padded)\n",
        "_test_pad_one_hot()"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "egM-2RHJ9Bnm"
      },
      "source": [
        "## Download model and vocabulary"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Up68255TDGKe",
        "outputId": "917015bc-56ce-41b1-a0a6-f7bbff55b28a"
      },
      "source": [
        "# Get a TensorFlow SavedModel (-q: quiet, -N: skip the download when the local\n",
        "# copy is already up to date).\n",
        "!wget -qN https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/models/single_domain_per_sequence_zipped_models/seed_random_32.0/5356760.tar.gz\n",
        "# unzip\n",
        "!tar xzf 5356760.tar.gz\n",
        "# Get the vocabulary for the model, which tells you which output index means which family.\n",
        "# The same -qN flags keep a re-run of this cell from saving a duplicate '.json.1' copy.\n",
        "!wget -qN https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/models/single_domain_per_sequence_zipped_models/trained_model_pfam_32.0_vocab.json"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "--2021-09-15 21:04:58--  https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/models/single_domain_per_sequence_zipped_models/trained_model_pfam_32.0_vocab.json\n",
            "Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.214.128, 173.194.216.128, 173.194.217.128, ...\n",
            "Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.214.128|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 197219 (193K) [application/octet-stream]\n",
            "Saving to: ‘trained_model_pfam_32.0_vocab.json’\n",
            "\n",
            "\r          trained_m   0%[                    ]       0  --.-KB/s               \rtrained_model_pfam_ 100%[===================>] 192.60K  --.-KB/s    in 0.002s  \n",
            "\n",
            "2021-09-15 21:04:58 (108 MB/s) - ‘trained_model_pfam_32.0_vocab.json’ saved [197219/197219]\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "XM3H7Y-rGPg8",
        "outputId": "945c3b69-589b-4a05-9746-71441e95290d"
      },
      "source": [
        "# Find the unzipped path: list everything whose name contains the model id.\n",
        "# The directory shown here is the path given to the SavedModel loader below.\n",
        "!ls *5356760*"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "5356760.tar.gz\n",
            "\n",
            "trn-_cnn_random__random_sp_gpu-cnn_for_random_pfam-5356760:\n",
            "saved_model.pb\tvariables\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MYaOgeZx9JrV"
      },
      "source": [
        "## Load the model into TensorFlow"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Bp1imZTA6L_0"
      },
      "source": [
        "# Create the graph first and bind the session to it, so that everything\n",
        "# restored through `sess` lives in `graph` rather than in the default graph.\n",
        "graph = tf.Graph()\n",
        "sess = tf.Session(graph=graph)"
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zbZfobaq6bkI",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "18ad7dc3-f5e6-4ed8-dc7c-c8cfda28d399"
      },
      "source": [
        "# Restore the SavedModel exported under the 'serve' tag using `sess`.\n",
        "with graph.as_default():\n",
        "  saved_model = tf.saved_model.load(\n",
        "      sess,\n",
        "      ['serve'],\n",
        "      'trn-_cnn_random__random_sp_gpu-cnn_for_random_pfam-5356760')"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "INFO:tensorflow:Restoring parameters from trn-_cnn_random__random_sp_gpu-cnn_for_random_pfam-5356760/variables/variables\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IR0Gl4Mx9Mgm"
      },
      "source": [
        "## Load tensors for class confidence prediction"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zbapL8YS7FpE"
      },
      "source": [
        "# All tensor names below come from the model's 'confidences' signature.\n",
        "class_confidence_signature = saved_model.signature_def['confidences']\n",
        "\n",
        "# Name of the output tensor to fetch.\n",
        "class_confidence_signature_tensor_name = (\n",
        "    class_confidence_signature.outputs['output'].name)\n",
        "\n",
        "# Names of the input tensors to feed: the one-hot sequences and their lengths.\n",
        "sequence_input_tensor_name = (\n",
        "    class_confidence_signature.inputs['sequence'].name)\n",
        "sequence_lengths_input_tensor_name = (\n",
        "    class_confidence_signature.inputs['sequence_length'].name)"
      ],
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p9iUCIO99QPs"
      },
      "source": [
        "# Predict Pfam label for domain"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AB_fZuxm8KL8"
      },
      "source": [
        "# Input for inference: a full protein sequence and the domain within it to classify.\n",
        "hemoglobin = 'MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR'\n",
        "globin_domain = hemoglobin[6:107]  # Python slices are 0-indexed and exclude the right endpoint: this is residues 7-107, 1-indexed inclusive"
      ],
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cQ1_wLxE7SeS"
      },
      "source": [
        "# To classify a different protein domain, assign its sequence to\n",
        "# `globin_domain` (see the previous cell) and run this cell again; it can be\n",
        "# run as many times as you like.\n",
        "\n",
        "# The first run of this cell will be slower; the subsequent runs will be fast.\n",
        "# This is because on the first run, the TensorFlow XLA graph is compiled, and\n",
        "# then is reused.\n",
        "with graph.as_default():\n",
        "  confidences_by_class = sess.run(\n",
        "      fetches=class_confidence_signature_tensor_name,\n",
        "      feed_dict={\n",
        "          # Inputs are batched: passing many sequences at once can speed up\n",
        "          # inference. Here the batch holds a single sequence.\n",
        "          sequence_input_tensor_name: [residues_to_one_hot(globin_domain)],\n",
        "          sequence_lengths_input_tensor_name: [len(globin_domain)],\n",
        "      },\n",
        "  )"
      ],
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "erzbBt4nZjzE",
        "outputId": "49b6072e-e438-4259-f712-15f8739c63cc"
      },
      "source": [
        "# Shape of the model input: (batch, sequence length, number of amino acids).\n",
        "residues_to_one_hot(globin_domain)[np.newaxis, ...].shape"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(1, 101, 20)"
            ]
          },
          "metadata": {},
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "w0D54NI18z0p",
        "outputId": "18b2d22e-9087-46fc-d88e-db30f162e015"
      },
      "source": [
        "# One row per input sequence; each column is the confidence for one output\n",
        "# index, which the vocabulary file maps to a Pfam family.\n",
        "confidences_by_class"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array([[3.7189863e-20, 4.8992849e-21, 3.2680125e-21, ..., 1.8260855e-19,\n",
              "        1.8322259e-19, 1.8438370e-19]], dtype=float32)"
            ]
          },
          "metadata": {},
          "execution_count": 12
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_BMjr7rr9TBt"
      },
      "source": [
        "## Map the model's prediction to a Pfam family accession"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tV6kx-97GIz6"
      },
      "source": [
        "# Load vocab: maps a model output index to a Pfam family accession.\n",
        "with open('trained_model_pfam_32.0_vocab.json') as f:\n",
        "  vocab = json.load(f)"
      ],
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "N4rw5aUzGw18",
        "outputId": "71395989-4167-482a-8a78-98cc632ad7f4"
      },
      "source": [
        "# Find what the most likely class is. Row 0 is selected explicitly because\n",
        "# argmax over the whole 2-D array returns a flattened index, which equals the\n",
        "# class index only when the batch holds a single sequence.\n",
        "np.argmax(confidences_by_class[0])"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "8505"
            ]
          },
          "metadata": {},
          "execution_count": 14
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "EyKlb172GyJW",
        "outputId": "0a570646-fb0a-4dd9-d461-40651bee77b7"
      },
      "source": [
        "# Look up the top-scoring class directly instead of hard-coding its index,\n",
        "# so this cell stays correct when a different sequence is classified.\n",
        "vocab[np.argmax(confidences_by_class[0])]  # PF00042 is family Globin"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'PF00042'"
            ]
          },
          "metadata": {},
          "execution_count": 15
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CSJ-lE1rFpFi"
      },
      "source": [
        "## If you want to predict for a bunch of sequences, you can run inference on a batch instead of one-by-one to make it faster"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5Lnmm-4eFpFj"
      },
      "source": [
        "hemoglobin = 'MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR'\n",
        "globin_domain = hemoglobin[6:107]  # Python slices are 0-indexed and exclude the right endpoint: this is residues 7-107, 1-indexed inclusive\n",
        "\n",
        "# Second sequence for the batch: Coronavirus spike glycoprotein S2 (PF01601).\n",
        "covid_spike_protein_domain = \"NSIAIPTNFTISVTTEILPVSMTKTSVDCTMYICGDSTECSNLLLQYGSFCTQLNRALTGIAVEQDKNTQEVFAQVKQIYKTPPIKDFGGFNFSQILPDPSKPSKRSFIEDLLFNKVTLADAGFIKQYGDCLGDIAARDLICAQKFNGLTVLPPLLTDEMIAQYTSALLAGTITSGWTFGAGAALQIPFAMQMAYRFNGIGVTQNVLYENQKLIANQFNSAIGKIQDSLSSTASALGKLQDVVNQNAQALNTLVKQLSSNFGAISSVLNDILSRLDKVEAEVQIDRLITGRLQSLQTYVTQQLIRAAEIRASANLAATKMSECVLGQSKRVDFCGKGYHLMSFPQSAPHGVVFLHVTYVPAQEKNFTTAPAICHDGKAHFPREGVFVSNGTHWFVTQRNFYEPQIITTDNTFVSGNCDVVIGIVNNTVYDPLQPELDSFKEELDKYFKNHTSPDVDLGDISGINASVVNIQKEIDRLNEVAKNLNESLIDLQELGKYEQYIKWPWYIWLGFIAGLIAIVMVTIM\""
      ],
      "execution_count": 16,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8QeV--uMIzIk"
      },
      "source": [
        "# Concatenate and pad sequence inputs\n",
        "one_hot_sequence_inputs = [\n",
        "              residues_to_one_hot(globin_domain),\n",
        "              residues_to_one_hot(covid_spike_protein_domain),\n",
        "]\n",
        "\n",
        "max_len_within_batch = max(len(globin_domain), len(covid_spike_protein_domain))\n",
        "padded_sequence_inputs = [pad_one_hot_sequence(s, max_len_within_batch)\n",
        "                          for s in one_hot_sequence_inputs]"
      ],
      "execution_count": 17,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EpquKRJoFpFj"
      },
      "source": [
        "# The first run of this cell will be slower; the subsequent runs will be fast.\n",
        "# This is because on the first run, the TensorFlow XLA graph is compiled, and\n",
        "# then is reused.\n",
        "with graph.as_default():\n",
        "  confidences_by_class = sess.run(\n",
        "      fetches=class_confidence_signature_tensor_name,\n",
        "      feed_dict={\n",
        "          # Padded one-hot arrays, one per sequence in the batch.\n",
        "          sequence_input_tensor_name: padded_sequence_inputs,\n",
        "          # Unpadded length of each sequence, in the same order.\n",
        "          sequence_lengths_input_tensor_name: [\n",
        "              len(globin_domain),\n",
        "              len(covid_spike_protein_domain),\n",
        "          ],\n",
        "      },\n",
        "  )"
      ],
      "execution_count": 18,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "uO7jvH-dFpFk",
        "outputId": "54f210e6-7d24-4e76-e63a-54a6fe096de9"
      },
      "source": [
        "vocab[np.argmax(confidences_by_class[0])] # Row 0 of the batch is the hemoglobin globin domain; PF00042 is family Globin"
      ],
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'PF00042'"
            ]
          },
          "metadata": {},
          "execution_count": 19
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 0
        },
        "id": "cNNAyMa9Jhlp",
        "outputId": "05da5d34-37fd-4eb1-ef20-2aaf7730f34c"
      },
      "source": [
        "vocab[np.argmax(confidences_by_class[1])] # Row 1 of the batch is the coronavirus spike domain; PF01601 is Coronavirus spike glycoprotein S2"
      ],
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'PF01601'"
            ]
          },
          "metadata": {},
          "execution_count": 20
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P-e-xWYw9Uzl"
      },
      "source": [
        "# Compute embedding of domain"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_FRCNv3G9-hw"
      },
      "source": [
        "# The 'pooled_representation' signature exposes the embedding output tensor.\n",
        "embedding_signature = saved_model.signature_def['pooled_representation']\n",
        "embedding_signature_tensor_name = (\n",
        "    embedding_signature.outputs['output'].name)"
      ],
      "execution_count": 21,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Dfh4JW3x9-h9"
      },
      "source": [
        "# The first run of this cell will be slower; the subsequent runs will be fast.\n",
        "# This is because on the first run, the TensorFlow XLA graph is compiled, and\n",
        "# then is reused.\n",
        "with graph.as_default():\n",
        "  embedding = sess.run(\n",
        "      fetches=embedding_signature_tensor_name,\n",
        "      feed_dict={\n",
        "          # Inputs are batched: passing many sequences at once can speed up\n",
        "          # inference. Here the batch holds a single sequence.\n",
        "          sequence_input_tensor_name: [residues_to_one_hot(globin_domain)],\n",
        "          sequence_lengths_input_tensor_name: [len(globin_domain)],\n",
        "      },\n",
        "  )"
      ],
      "execution_count": 22,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fJRKUEwI89JR",
        "outputId": "5a2e7d15-2d09-40ec-85e8-904206c84a47"
      },
      "source": [
        "# Shape of embedding is (# seqs in batch, number of features in embedding space).\n",
        "# A single sequence was fed above, so there is one embedding row.\n",
        "embedding.shape"
      ],
      "execution_count": 23,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(1, 1100)"
            ]
          },
          "metadata": {},
          "execution_count": 23
        }
      ]
    }
  ]
}